{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 201,
   "metadata": {},
   "outputs": [],
   "source": [
    "from config import baidu_data_config\n",
    "from helpDataBaidu import load_data\n",
    "from pytorch_pretrained_bert import BertTokenizer\n",
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 202,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt = baidu_data_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = load_data(opt.train_data_dir)\n",
    "dev_data = load_data(opt.dev_data_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 182,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = BertTokenizer.from_pretrained('./bert-base-chinese/vocab_new.txt', do_lower_case=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 209,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_unk_origin_char_and_position(word_list, text):\n",
    "    \"\"\"\n",
    "    输入: 茶树茶网蝽，Stephanitis chinensis Drake，属半翅目网蝽科冠网椿属的一种昆虫\n",
    "    输出:\n",
    "    {4:'蝽', ...}表示UNK的位置和其对应字符\n",
    "    \"\"\"\n",
    "    res = {}\n",
    "    # print(word_list)\n",
    "    unk_ps = []\n",
    "    left2word = []\n",
    "    right2word = []\n",
    "    leftnum = []\n",
    "    rightnum = []\n",
    "    for idx, word in enumerate(word_list):\n",
    "        if word == '[UNK]':\n",
    "            unk_ps.append(idx)\n",
    "            leftword = ''\n",
    "            if idx >= 2:\n",
    "                leftword = word_list[idx-2]+word_list[idx-1]\n",
    "                leftword = leftword.replace('[UNK]','.')\n",
    "                leftnum.append(2)\n",
    "            elif idx >= 1:\n",
    "                leftword = word_list[idx-1]\n",
    "                leftword = leftword.replace('[UNK]','.')\n",
    "                leftnum.append(1)\n",
    "            else:\n",
    "                leftnum.append(0)\n",
    "            rightword = ''\n",
    "            if idx <= len(word_list) - 3:\n",
    "                rightword = word_list[idx+1] + word_list[idx+2]\n",
    "                rightword = rightword.replace('[UNK]','.')\n",
    "                rightnum.append(2)\n",
    "            elif idx <= len(word_list) - 2:\n",
    "                rightword = word_list[idx+1]\n",
    "                rightword = rightword.replace('[UNK]','.')\n",
    "                rightnum.append(1)\n",
    "            else:\n",
    "                rightnum.append(0)\n",
    "            left2word.append(leftword)\n",
    "            right2word.append(rightword)\n",
    "    #print(left2word)\n",
    "    # print(right2word)\n",
    "    for idx, p in enumerate(unk_ps):\n",
    "        # print(left2word[idx]+'.'+right2word[idx])\n",
    "        try:\n",
    "            start = re.search(left2word[idx]+'.'+right2word[idx], text.lower(), flags=0).start()\n",
    "            unk_word = text[start+leftnum[idx]]\n",
    "            res[start + leftnum[idx]] = unk_word\n",
    "        except:\n",
    "            pass\n",
    "#             print(left2word)\n",
    "#             print(right2word)\n",
    "#             print(left2word[idx]+'.'+right2word[idx])\n",
    "#             print(text.lower())\n",
    "#             print(word_list)\n",
    "    return res\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
